#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@File      :   CondConv.py
@Time      :   2024/02/26 20:14:09
@Author    :   CSDN迪菲赫尔曼 
@Version   :   1.0
@Reference :   https://blog.csdn.net/weixin_43694096
@Desc      :   CondConv (Conditionally Parameterized Convolutions, Yang et al., NeurIPS 2019)
"""


import torch
from torch import nn
from torch.nn import functional as F

__all__ = ["CondConv"]


class CondAttention(nn.Module):
    """Routing function: global average pooling + a 1x1 conv + sigmoid,
    producing K per-sample expert weights (the CondConv paper routes with
    a sigmoid rather than a softmax)."""

    def __init__(self, in_planes, K, init_weight=True):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.net = nn.Conv2d(in_planes, K, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()

        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        att = self.avgpool(x)  # bs,dim,1,1
        att = self.net(att).view(x.shape[0], -1)  # bs,K
        return self.sigmoid(att)


class CondConv(nn.Module):
    """CondConv layer: each sample is convolved with a weighted mixture of
    K expert kernels, with per-sample mixture weights from CondAttention."""

    def __init__(
        self,
        in_planes,
        out_planes,
        kernel_size,
        stride,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        K=4,
        init_weight=True,
    ):
        super().__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.K = K
        self.init_weight = init_weight
        self.attention = CondAttention(
            in_planes=in_planes, K=K, init_weight=init_weight
        )

        self.weight = nn.Parameter(
            torch.randn(K, out_planes, in_planes // groups, kernel_size, kernel_size),
            requires_grad=True,
        )
        if bias:
            self.bias = nn.Parameter(torch.randn(K, out_planes), requires_grad=True)
        else:
            self.bias = None

        if self.init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for i in range(self.K):
            nn.init.kaiming_uniform_(self.weight[i])

    def forward(self, x):
        bs, _, h, w = x.shape
        routing_weights = self.attention(x)  # bs,K
        # Fold the batch into the channel dimension so a single grouped
        # convolution applies a different aggregated kernel to each sample.
        x = x.view(1, -1, h, w)  # 1,bs*in_p,h,w
        weight = self.weight.view(self.K, -1)  # K,out_p*(in_p/g)*k*k
        aggregate_weight = torch.mm(routing_weights, weight).view(
            bs * self.out_planes,
            self.in_planes // self.groups,
            self.kernel_size,
            self.kernel_size,
        )  # bs*out_p,in_p/g,k,k

        if self.bias is not None:
            bias = self.bias.view(self.K, -1)  # K,out_p
            aggregate_bias = torch.mm(routing_weights, bias).view(-1)  # bs*out_p
        else:
            aggregate_bias = None

        output = F.conv2d(
            x,
            weight=aggregate_weight,
            bias=aggregate_bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups * bs,
        )
        # Spatial size follows the usual conv arithmetic, so take it from the
        # output instead of assuming the input's h,w (which breaks for
        # stride > 1).
        output = output.view(bs, self.out_planes, output.shape[-2], output.shape[-1])
        return output
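

# Per-sample reference implementation (an added sketch, not part of the
# original file): it mixes the K expert kernels explicitly for each sample and
# is used in the demo below only to sanity-check the batched grouped-conv
# trick in CondConv.forward.
def _condconv_reference(m, x):
    att = m.attention(x)  # bs,K
    outs = []
    for i in range(x.shape[0]):
        # Weighted mixture of the K expert kernels (and biases) for sample i.
        w = (att[i].view(-1, 1, 1, 1, 1) * m.weight).sum(dim=0)
        b = (att[i].view(-1, 1) * m.bias).sum(dim=0) if m.bias is not None else None
        outs.append(
            F.conv2d(
                x[i : i + 1],
                w,
                b,
                stride=m.stride,
                padding=m.padding,
                dilation=m.dilation,
                groups=m.groups,
            )
        )
    return torch.cat(outs, dim=0)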


if __name__ == "__main__":
    x = torch.randn(2, 32, 64, 64)
    m = CondConv(
        in_planes=32, out_planes=64, kernel_size=3, stride=1, padding=1, bias=False
    )
    out = m(x)
    print(out.shape)  # torch.Size([2, 64, 64, 64])
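
    # Added sanity checks (a sketch beyond the original demo): the batched
    # grouped-conv trick should match the per-sample reference above, and with
    # stride=2 the output shape should follow the usual conv arithmetic.
    ref = _condconv_reference(m, x)
    print(torch.allclose(out, ref, atol=1e-5))  # True

    m2 = CondConv(in_planes=32, out_planes=64, kernel_size=3, stride=2, padding=1)
    print(m2(x).shape)  # torch.Size([2, 64, 32, 32])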
